
<!DOCTYPE html>
<html>
  <head>
    
<meta charset="utf-8" >

<title>Kuaishou Bagua tutorial: translation and excerpts | dragon</title>
<meta name="description" content="Email (base64): MTY5MDMwMjk2M0BxcS5jb20=
">

<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/3.7.0/animate.min.css">

<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.7.2/css/all.css" integrity="sha384-fnmOCqbTlWIlj8LyTjo7mOUStjsKC4pOpQbqyi7RrhN7udi9RwhKkMHpvLbHG9Sr" crossorigin="anonymous">
<link rel="shortcut icon" href="https://dragonfive.gitee.io//favicon.ico?v=1740893463017">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/KaTeX/0.10.0/katex.min.css">
<link rel="stylesheet" href="https://dragonfive.gitee.io//styles/main.css">



<script src="https://cdn.jsdelivr.net/npm/vue/dist/vue.js"></script>
<script src="//cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.5.1/build/highlight.min.js"></script>



  </head>
  <body>
    <div id="app" class="main">
      <div class="site-header-container">
  <div class="site-header">
    <div class="left">
      <a href="https://dragonfive.gitee.io/">
        <img class="avatar" src="https://dragonfive.gitee.io//images/avatar.png?v=1740893463017" alt="" width="32px" height="32px">
      </a>
      <a href="https://dragonfive.gitee.io/">
        <h1 class="site-title">dragon</h1>
      </a>
    </div>
    <div class="right">
      <transition name="fade">
        <i class="icon" :class="{ 'icon-close-outline': menuVisible, 'icon-menu-outline': !menuVisible }" @click="menuVisible = !menuVisible"></i>
      </transition>
    </div>
  </div>
</div>

<transition name="fade">
  <div class="menu-container" style="display: none;" v-show="menuVisible">
    <div class="menu-list">
      
        
          <a href="/" class="menu purple-link">
            Home
          </a>
        
      
        
          <a href="/archives" class="menu purple-link">
            Archives
          </a>
        
      
        
          <a href="/tags" class="menu purple-link">
            Tags
          </a>
        
      
        
          <a href="/post/about" class="menu purple-link">
            About
          </a>
        
      
    </div>
  </div>
</transition>


      <div class="content-container">
        <div class="post-detail">
          
          <h2 class="post-title">Kuaishou Bagua tutorial: translation and excerpts</h2>
          <div class="post-info post-detail-info">
            <span><i class="icon-calendar-outline"></i> 2021-09-01</span>
            
              <span>
                <i class="icon-pricetags-outline"></i>
                
                  <a href="https://dragonfive.gitee.io/tag/xun-lian-kuang-jia/">
                    Training frameworks
                    
                      ,
                    
                  </a>
                
                  <a href="https://dragonfive.gitee.io/tag/WzibKNMac/">
                    Deep learning
                    
                  </a>
                
              </span>
            
          </div>
          <div class="post-content" v-pre>
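            <p>Translated from the <a href="https://bagua-tutorials.kwai-seattle.com/introduction">Bagua documentation</a>.</p>
<p>The primitives Bagua has integrated so far include:</p>
<ul>
<li>Centralized synchronous communication (AllReduce)</li>
<li>Decentralized synchronous communication</li>
<li>Low-precision communication</li>
</ul>
<p>Its effectiveness has been validated across a range of scenarios and models, including VGG and ResNet on ImageNet, BERT Large, and several large-scale industrial applications at Kuaishou:</p>
<ul>
<li>recommender systems that train models with tens of terabytes of parameters,</li>
<li>video and image understanding over more than one billion images and videos,</li>
<li>ASR with terabyte-scale datasets,</li>
<li>and so on.</li>
</ul>
<h1 id="使用举例">Usage example</h1>
<p>Using Bagua is similar to using other distributed training libraries such as PyTorch DDP and Horovod.</p>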
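<pre><code class="language-python">import torch
import torch.nn.functional as F


def main():
    # parse_args, MyNet and train_loader are placeholders for your own code;
    # distributed setup (process group, device selection) is omitted here
    args = parse_args()
    # define your model and optimizer
    model = MyNet().to(args.device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    # wrap the model with a Bagua communication algorithm
    from bagua.torch_api.algorithms import gradient_allreduce
    model = model.with_bagua(
        [optimizer], gradient_allreduce.GradientAllReduceAlgorithm()
    )

    # train the model over the dataset
    for epoch in range(args.epochs):
        for b_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            outputs = model(inputs)
            # functional form: calling the CrossEntropyLoss class on tensors would not compute a loss
            loss = F.cross_entropy(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

</code></pre>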
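<p>The computation itself is plain PyTorch; you only need to wrap the model with a communication algorithm primitive.</p>
<h1 id="通信算法">Communication algorithms</h1>
<p>Bagua thrives on the diversity of distributed learning algorithms. The system's flexibility makes it possible to incorporate various state-of-the-art algorithms smoothly while automatically optimizing performance during execution. For end users, Bagua offers a wide choice of algorithms that they can easily try on their own tasks. For algorithm developers, Bagua is a playground where they can focus on the algorithm itself (for example, its logic and control flow) without reinventing the wheel across algorithms (for example, communication primitives and system optimizations).</p>
<h2 id="梯度-allreduce">Gradient AllReduce</h2>
<h3 id="总览">Overview</h3>
<p>Gradient AllReduce is a popular synchronous data-parallel distributed algorithm. It is the algorithm implemented in most existing solutions, such as <a href="https://pytorch.org/tutorials/intermediate/ddp_tutorial.html">PyTorch DistributedDataParallel</a>, <a href="https://horovod.ai/">Horovod</a>, and <a href="https://www.tensorflow.org/guide/distributed_training">TensorFlow Mirrored Strategy</a>.</p>
<h3 id="算法">Algorithm</h3>
<p>With this algorithm, each worker performs the following steps in every iteration.</p>
<ul>
<li>Compute the gradient on a mini-batch.</li>
<li>Use the AllReduce collective to compute the average of the gradients across all workers.</li>
<li>Update the model with the averaged gradient.</li>
</ul>
<p>In Bagua, this algorithm is provided by the GradientAllReduce algorithm class. With default settings, Bagua's GradientAllReduce implementation should perform on par with PyTorch DDP and, in most cases, faster than Horovod. Bagua also supports additional optimizations, such as hierarchical communication, which can be configured when instantiating the GradientAllReduce class. In some situations, for example when the inter-machine network is the bottleneck, these can make Bagua faster than other implementations.</p>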
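<p>The three per-iteration steps listed above can be sketched in plain Python. This is only a single-process simulation for illustration, not Bagua code: <code>allreduce_mean</code> and <code>sgd_step</code> are hypothetical helpers, and the gradient values are made up.</p>
<pre><code class="language-python">def allreduce_mean(worker_grads):
    """Average per-worker gradient vectors; every worker receives the same result."""
    n = len(worker_grads)
    dim = len(worker_grads[0])
    mean = [sum(g[k] for g in worker_grads) / n for k in range(dim)]
    return [list(mean) for _ in range(n)]


def sgd_step(params, grad, lr):
    """Plain SGD update applied to one worker's parameters."""
    return [p - lr * g for p, g in zip(params, grad)]


# Step 1: each worker computes a gradient on its own mini-batch (made-up values).
local_grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
# Step 2: AllReduce leaves every worker holding the average gradient.
averaged = allreduce_mean(local_grads)
# Step 3: every worker applies the same update, so the replicas stay identical.
params = [[0.0, 0.0] for _ in local_grads]
params = [sgd_step(p, g, lr=0.5) for p, g in zip(params, averaged)]
print(params[0])
</code></pre>
<p>Because every worker starts from the same parameters and applies the same averaged gradient, all replicas remain identical after the update.</p>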
<h3 id="举例">Example</h3>
<p>A complete example of running Gradient AllReduce can be found in the <a href="https://github.com/BaguaSys/examples/blob/main/benchmark/synthetic_benchmark.py">Bagua examples</a>, using the --algorithm gradient_allreduce command-line argument.</p>
<p>You need to initialize the Bagua algorithm (see the <a href="https://bagua.readthedocs.io/en/latest/autoapi/bagua/torch_api/algorithms/gradient_allreduce/index.html">API documentation</a> for the parameters you can customize):</p>
<pre><code class="language-python">from bagua.torch_api.algorithms import gradient_allreduce
algorithm = gradient_allreduce.GradientAllReduceAlgorithm()
</code></pre>
<p>Then wrap your model with it:</p>
<pre><code class="language-python">model = model.with_bagua([optimizer], algorithm)
</code></pre>
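<h2 id="bytegrad">ByteGrad</h2>
<h3 id="总览-2">Overview</h3>
<p>Large-scale distributed training incurs significant communication cost, especially for large models. For example, in a traditional synchronous distributed setup that uses AllReduce to synchronize gradients (as many libraries do, such as <a href="https://github.com/horovod/horovod">Horovod</a> and <a href="https://pytorch.org/tutorials/intermediate/ddp_tutorial.html">PyTorch DDP</a>), every worker has to send and receive gradients equal in size to the model in each training iteration. This communication cost quickly becomes the training bottleneck in many scenarios.</p>
<p>There are many <a href="https://awesomeopensource.com/project/chester256/Model-Compression-Papers?categoryPage=21">existing papers</a> on applying model or gradient compression to reduce this cost. Bagua provides a built-in gradient compression algorithm called ByteGrad, which compresses each gradient floating-point value to an 8-bit byte before communication. Relative to 32-bit floats, this saves 3/4 of the original communication cost. It implements a high-accuracy min-max quantization operator with optimized CUDA kernels and hierarchical communication. This makes it much faster (about 50% in the Bagua authors' benchmarks) than compression implementations in existing frameworks, such as <a href="https://pytorch.org/docs/stable/ddp_comm_hooks.html#powersgd-communication-hook">PyTorch PowerSGD</a>. In addition, given the same number of epochs, ByteGrad converges to roughly the same result as the full-precision algorithm on most tasks.</p>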
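<p>The 3/4 figure follows directly from the bit widths. A quick back-of-the-envelope check, assuming gradients are normally stored as 32-bit floats and ignoring the small per-tensor overhead of sending the minimum and maximum (the parameter count below is hypothetical):</p>
<pre><code class="language-python">def gradient_bytes(num_params, bits_per_value):
    """Bytes needed to send one gradient value per parameter."""
    return num_params * bits_per_value // 8


num_params = 25_000_000  # hypothetical model size
full = gradient_bytes(num_params, 32)       # uncompressed float32 gradients
compressed = gradient_bytes(num_params, 8)  # one byte per value after quantization
saving = 1 - compressed / full
print(full, compressed, saving)
</code></pre>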
<h3 id="算法-2">Algorithm</h3>
<p>ByteGrad performs the following steps in each iteration. Assume there are m nodes, each with n GPUs.</p>
<ol>
<li>For all <span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>i</mi><mo separator="true">,</mo><mi>j</mi></mrow><annotation encoding="application/x-tex">i,j</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.85396em;vertical-align:-0.19444em;"></span><span class="mord mathdefault">i</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathdefault" style="margin-right:0.05724em;">j</span></span></span></span>, compute the gradient <span class="katex"><span class="katex-mathml"><math><semantics><mrow><msub><mi>g</mi><mrow><mi>i</mi><mo separator="true">,</mo><mi>j</mi></mrow></msub></mrow><annotation encoding="application/x-tex">g_{i,j}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.716668em;vertical-align:-0.286108em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right:0.03588em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.311664em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight">i</span><span class="mpunct mtight">,</span><span class="mord mathdefault mtight" style="margin-right:0.05724em;">j</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span></span></span></span> on the j-th GPU of the i-th node.</li>
<li>The first GPU on each node performs a reduce operation to average the gradients of all GPUs on that node; denote the result on the i-th node by <span class="katex"><span class="katex-mathml"><math><semantics><mrow><msub><mi>G</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">G_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathdefault">G</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathdefault mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span>.</li>
<li>For every i, the first GPU on the i-th node quantizes <span class="katex"><span class="katex-mathml"><math><semantics><mrow><msub><mi>G</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">G_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.83333em;vertical-align:-0.15em;"></span><span class="mord"><span class="mord mathdefault">G</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathdefault mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span> with a quantization function <span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>Q</mi><mo>(</mo><mo>⋅</mo><mo>)</mo></mrow><annotation encoding="application/x-tex">Q(\cdot)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathdefault">Q</span><span class="mopen">(</span><span class="mord">⋅</span><span class="mclose">)</span></span></span></span>, producing <span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>Q</mi><mo>(</mo><msub><mi>G</mi><mi>i</mi></msub><mo>)</mo></mrow><annotation encoding="application/x-tex">Q(G_i)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathdefault">Q</span><span class="mopen">(</span><span class="mord"><span class="mord mathdefault">G</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathdefault mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span>. The nodes then exchange these quantized gradients so that each node holds the average of all <span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>Q</mi><mo>(</mo><msub><mi>G</mi><mi>i</mi></msub><mo>)</mo></mrow><annotation encoding="application/x-tex">Q(G_i)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathdefault">Q</span><span class="mopen">(</span><span class="mord"><span class="mord mathdefault">G</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathdefault mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span>.</li>
<li>The first GPU on each node broadcasts the average of all <span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>Q</mi><mo>(</mo><msub><mi>G</mi><mi>i</mi></msub><mo>)</mo></mrow><annotation encoding="application/x-tex">Q(G_i)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathdefault">Q</span><span class="mopen">(</span><span class="mord"><span class="mord mathdefault">G</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathdefault mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span> to the other GPUs on the same node, and every GPU on every worker updates its model with this quantized average.</li>
</ol>
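<p>The quantization function Q(⋅) computes the minimum x and maximum y of its input and splits [x, y] into 256 evenly spaced intervals. It then represents each input element with an 8-bit integer indicating which interval the element falls in.</p>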
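<p>A minimal pure-Python sketch of such a min-max quantizer is shown below. It is not Bagua's CUDA kernel: the function names are hypothetical, and decoding each code to the midpoint of its interval is an assumption, since the text above only describes the encoding side.</p>
<pre><code class="language-python">def quantize(values, levels=256):
    """Map each value to the index of the interval of [min, max] it falls in."""
    lo, hi = min(values), max(values)
    width = (hi - lo) / levels
    if width == 0:
        # All values are equal, so a single interval is enough.
        return [0] * len(values), lo, hi
    # The maximum would land on index `levels`, so clamp it into the last interval.
    codes = [min(int((v - lo) / width), levels - 1) for v in values]
    return codes, lo, hi


def dequantize(codes, lo, hi, levels=256):
    """Decode each index to the midpoint of its interval (an assumed convention)."""
    width = (hi - lo) / levels
    return [lo + (c + 0.5) * width for c in codes]


grads = [-1.0, -0.5, 0.0, 0.5, 1.0]
codes, lo, hi = quantize(grads)
restored = dequantize(codes, lo, hi)
print(codes)
print(restored)
</code></pre>
<p>With midpoint decoding, the reconstruction error of each element is bounded by half an interval width, that is (y − x) / 512.</p>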
<h3 id="用法举例">Usage example</h3>
<p>You need to initialize the Bagua algorithm (see the <a href="https://bagua.readthedocs.io/en/latest/autoapi/bagua/torch_api/algorithms/bytegrad/index.html">API documentation</a> for the parameters you can customize):</p>
<pre><code class="language-python">from bagua.torch_api.algorithms import bytegrad
algorithm = bytegrad.ByteGradAlgorithm()
model = model.with_bagua([optimizer], algorithm)
</code></pre>
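<h2 id="去中心化sgd">Decentralized SGD</h2>
<h3 id="去中心化训练概述">Overview of decentralized training</h3>
<p>Decentralized SGD is a data-parallel distributed learning algorithm that removes the need for a centralized global model shared by all workers, which makes its communication pattern quite different from Allreduce-based or parameter-server-based algorithms. With decentralized SGD, each worker only exchanges data with one or a few specific workers instead of aggregating data globally. As a result, decentralized communication uses far fewer connections than Allreduce and balances communication load better than a parameter server. Although decentralized SGD can leave each worker with a different model, it has been proven theoretically that decentralized SGD converges at the same rate as its centralized counterpart. A detailed analysis of decentralized SGD can be found in the original authors' <a href="https://arxiv.org/abs/1705.09056">paper</a>.</p>
<h3 id="去中心化训练算法">Decentralized training algorithms</h3>
<p>New decentralized training algorithms are proposed regularly. These works focus on different aspects of decentralized training, such as <strong>peer selection, data compression, and asynchrony</strong>, and offer many promising insights. So far, Bagua has integrated two basic decentralized algorithms: decentralized SGD and low-precision decentralized SGD. With Bagua's automatic system support for decentralization, more decentralized algorithms are expected to be implemented in the near future.</p>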
<h3 id="decentralized-sgd">Decentralized SGD</h3>
<p>This section describes the decentralized SGD algorithm implemented in Bagua. Let n be the number of workers and let the model parameters on worker i be <span class="katex"><span class="katex-mathml"><math><semantics><mrow><msup><mi>x</mi><mrow><mo>(</mo><mi>i</mi><mo>)</mo></mrow></msup><mo separator="true">,</mo><mi>i</mi><mo>∈</mo><mo>{</mo><mn>0</mn><mo separator="true">,</mo><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><mo separator="true">,</mo><mi>n</mi><mo>−</mo><mn>1</mn><mo>}</mo></mrow><annotation encoding="application/x-tex">x^{(i)}, i \in \{0,...,n-1\}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.0824399999999998em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathdefault">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathdefault mtight">i</span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathdefault">i</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">∈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">{</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord">.</span><span class="mord">.</span><span class="mord">.</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathdefault">n</span><span class="mord">−</span><span class="mord">1</span><span class="mclose">}</span></span></span></span>. Every worker can send data to, and receive data from, any other worker directly. In each iteration t, the algorithm repeats the following steps:</p>
<ol>
<li>Each worker i computes its local gradient <span class="katex"><span class="katex-mathml"><math><semantics><mrow><msubsup><mi>g</mi><mi>t</mi><mrow><mo>(</mo><mi>i</mi><mo>)</mo></mrow></msubsup></mrow><annotation encoding="application/x-tex">g_t^{(i)}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.2905559999999998em;vertical-align:-0.24575599999999992em;"></span><span class="mord"><span class="mord mathdefault" style="margin-right:0.03588em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.454244em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathdefault mtight">t</span></span></span><span style="top:-3.2197999999999998em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathdefault mtight">i</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24575599999999992em;"><span></span></span></span></span></span></span></span></span></span> for the current iteration.</li>
<li>Each worker averages its local model with the model of its selected peer (denoted <span class="katex"><span class="katex-mathml"><math><semantics><mrow><msubsup><mi>x</mi><mi>t</mi><mrow><mo>(</mo><mi>j</mi><mo>)</mo></mrow></msubsup></mrow><annotation encoding="application/x-tex">x_t^{(j)}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.2905559999999998em;vertical-align:-0.24575599999999992em;"></span><span class="mord"><span class="mord mathdefault">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.454244em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathdefault mtight">t</span></span></span><span style="top:-3.2197999999999998em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathdefault mtight" style="margin-right:0.05724em;">j</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24575599999999992em;"><span></span></span></span></span></span></span></span></span></span>): <span class="katex"><span class="katex-mathml"><math><semantics><mrow><msubsup><mi>x</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn><mi mathvariant="normal">/</mi><mn>2</mn></mrow><mrow><mo>(</mo><mi>i</mi><mo>)</mo></mrow></msubsup><mo>=</mo><mo>(</mo><msubsup><mi>x</mi><mi>t</mi><mrow><mo>(</mo><mi>i</mi><mo>)</mo></mrow></msubsup><mo>+</mo><msubsup><mi>x</mi><mi>t</mi><mrow><mo>(</mo><mi>j</mi><mo>)</mo></mrow></msubsup><mo>)</mo><mi mathvariant="normal">/</mi><mn>2</mn></mrow><annotation encoding="application/x-tex">x_{t+1/2}^{(i)}=(x_t^{(i)}+x_t^{(j)})/2</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.4858em;vertical-align:-0.441em;"></span><span class="mord"><span class="mord mathdefault">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.433692em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight">t</span><span class="mbin mtight">+</span><span class="mord mtight">1</span><span class="mord mtight">/</span><span class="mord mtight">2</span></span></span></span><span style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathdefault mtight">i</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.441em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.2948em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord"><span class="mord mathdefault">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.454244em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathdefault mtight">t</span></span></span><span style="top:-3.2197999999999998em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathdefault mtight">i</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24575599999999992em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.2948em;vertical-align:-0.25em;"></span><span class="mord"><span class="mord mathdefault">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.454244em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathdefault mtight">t</span></span></span><span style="top:-3.2197999999999998em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathdefault mtight" style="margin-right:0.05724em;">j</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24575599999999992em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mord">/</span><span class="mord">2</span></span></span></span>.</li>
<li>Each worker updates the averaged model with its local gradient, where γ is the learning rate: <span class="katex"><span class="katex-mathml"><math><semantics><mrow><msubsup><mi>x</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn></mrow><mrow><mo>(</mo><mi>i</mi><mo>)</mo></mrow></msubsup><mo>=</mo><msubsup><mi>x</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn><mi mathvariant="normal">/</mi><mn>2</mn></mrow><mrow><mo>(</mo><mi>i</mi><mo>)</mo></mrow></msubsup><mo>−</mo><mi>γ</mi><msubsup><mi>g</mi><mi>t</mi><mrow><mo>(</mo><mi>i</mi><mo>)</mo></mrow></msubsup></mrow><annotation encoding="application/x-tex">x_{t+1}^{(i)}=x_{t+1/2}^{(i)}-\gamma g_t^{(i)}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.3694389999999999em;vertical-align:-0.3246389999999999em;"></span><span class="mord"><span class="mord mathdefault">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.433692em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight">t</span><span class="mbin mtight">+</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathdefault mtight">i</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.3246389999999999em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.4858em;vertical-align:-0.441em;"></span><span class="mord"><span class="mord mathdefault">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.433692em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathdefault mtight">t</span><span class="mbin mtight">+</span><span class="mord mtight">1</span><span class="mord mtight">/</span><span class="mord mtight">2</span></span></span></span><span style="top:-3.2198em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathdefault mtight">i</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.441em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.2905559999999998em;vertical-align:-0.24575599999999992em;"></span><span class="mord mathdefault" style="margin-right:0.05556em;">γ</span><span class="mord"><span class="mord mathdefault" style="margin-right:0.03588em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.0448em;"><span style="top:-2.454244em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathdefault mtight">t</span></span></span><span style="top:-3.2197999999999998em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mopen mtight">(</span><span class="mord mathdefault mtight">i</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24575599999999992em;"><span></span></span></span></span></span></span></span></span></span>.</li>
</ol>
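<p>The update rule in the steps above can be simulated in a few lines of plain Python. This is an illustrative sketch, not Bagua code: each model is reduced to a single float, the gradients are made up, and the <code>peers</code> list (worker i is paired with worker <code>peers[i]</code>) is supplied by hand.</p>
<pre><code class="language-python">def decentralized_sgd_step(models, grads, peers, lr):
    """One iteration: average with the paired peer, then apply the local gradient."""
    # Step 2: x_{t+1/2}^{(i)} = (x_t^{(i)} + x_t^{(j)}) / 2, using the models from iteration t.
    averaged = [(models[i] + models[peers[i]]) / 2 for i in range(len(models))]
    # Step 3: x_{t+1}^{(i)} = x_{t+1/2}^{(i)} - lr * g_t^{(i)}.
    return [x - lr * g for x, g in zip(averaged, grads)]


models = [0.0, 2.0, 4.0, 6.0]   # one scalar "model" per worker
grads = [1.0, 1.0, -1.0, -1.0]  # step 1: made-up local gradients
peers = [2, 3, 0, 1]            # worker 0 pairs with 2, worker 1 with 3
new_models = decentralized_sgd_step(models, grads, peers, lr=0.5)
print(new_models)
</code></pre>
<p>Paired workers share the same averaged model after step 2 but apply different local gradients in step 3, so the workers' models generally differ after each iteration.</p>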
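<p>In step 2, a strategy selects one peer for each worker in every iteration so that all workers are properly paired and data exchange is efficient, since each worker can exchange data with a different peer from one iteration to the next. In short, the strategy splits the workers evenly into two groups and dynamically pairs workers across the two groups, with a different pairing in each iteration.</p>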
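<p>One way to realize such a pairing is a simple shifting schedule, sketched below. The exact rule Bagua uses is not given in the text above, so <code>pair_workers</code> and its shift-by-t rule are assumptions made purely for illustration.</p>
<pre><code class="language-python">def pair_workers(n, t):
    """Return peers, where peers[i] is the partner of worker i at iteration t."""
    if n % 2 != 0:
        raise ValueError("this sketch needs an even number of workers")
    half = n // 2
    peers = [0] * n
    for i in range(half):
        # Worker i from the first group meets a different member of the
        # second group as t advances (hypothetical shifting rule).
        j = half + (i + t) % half
        peers[i] = j
        peers[j] = i
    return peers


for t in range(3):
    print(t, pair_workers(4, t))
</code></pre>
<p>Every worker has exactly one peer per iteration, and the pairing is symmetric: if worker i is paired with worker j, then worker j is paired with worker i.</p>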
<blockquote>
<p>Comment: in the end, whose parameters are used as the final checkpoint?</p>
</blockquote>
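<h3 id="通信开销">Communication overhead</h3>
<p>The communication overhead of decentralized SGD is closely tied to the degree of the network, that is, the number of connections a worker has to other workers. Different topologies or strategies lead to different network degrees. The decentralized SGD algorithm described above clearly has a network degree of 1. In each iteration, therefore, a worker only needs to establish one connection with one other worker and exchange data amounting to 1× the model size. The table below compares the communication complexity of different communication patterns in terms of latency and bandwidth at the busiest node.</p>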
<table>
<thead>
<tr>
<th style="text-align:center">Algorithm</th>
<th style="text-align:center">Latency complexity</th>
<th style="text-align:center">Bandwidth complexity</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align:center">Allreduce (Ring)</td>
<td style="text-align:center">O(n)</td>
<td style="text-align:center">O(1)</td>
</tr>
<tr>
<td style="text-align:center">Parameter server</td>
<td style="text-align:center">O(1)</td>
<td style="text-align:center">O(n)</td>
</tr>
<tr>
<td style="text-align:center">Bagua decentralized SGD</td>
<td style="text-align:center">O(1)</td>
<td style="text-align:center">O(1)</td>
</tr>
</tbody>
</table>
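<p>The orders of growth in the table can be checked with a rough cost model. The sketch below is a simplification under stated assumptions: ring AllReduce is modeled as a reduce-scatter followed by an all-gather, the parameter server is a single node that talks to all n workers concurrently, and <code>busiest_node_cost</code> is a hypothetical helper rather than part of Bagua.</p>
<pre><code class="language-python">def busiest_node_cost(pattern, n, model_size):
    """Return (sequential_steps, data_sent) for the busiest node in one iteration."""
    if pattern == "ring_allreduce":
        # Reduce-scatter plus all-gather: 2(n-1) steps, each sending model_size / n.
        return 2 * (n - 1), 2 * (n - 1) * model_size / n
    if pattern == "parameter_server":
        # The server sends the updated model to each of the n workers.
        return 1, n * model_size
    if pattern == "decentralized":
        # One exchange of one model with a single peer.
        return 1, model_size
    raise ValueError(pattern)


for pattern in ("ring_allreduce", "parameter_server", "decentralized"):
    print(pattern, busiest_node_cost(pattern, n=8, model_size=100.0))
</code></pre>
<p>As n grows, the ring's step count grows linearly while its traffic per node stays below twice the model size; the parameter server's traffic grows linearly; the decentralized exchange stays constant on both axes.</p>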
<h3 id="用法举例-2">Usage example</h3>
<p>You need to initialize the Bagua algorithm (see the <a href="https://bagua.readthedocs.io/en/latest/autoapi/bagua/torch_api/algorithms/decentralized/index.html">API documentation</a> for the parameters you can customize):</p>
<pre><code class="language-python">from bagua.torch_api.algorithms import decentralized
algorithm = decentralized.DecentralizedAlgorithm()
model = model.with_bagua([optimizer], algorithm)
</code></pre>
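<h3 id="其他的算法">Other algorithms</h3>
<p>The other algorithms Bagua provides need little comment here; see the <a href="https://bagua-tutorials.kwai-seattle.com/algorithms/">original documentation</a>.</p>
<p>Bagua also supports <a href="https://baguasys.github.io/tutorials/elastic-training/index.html">elastic training</a>, which appears to be implemented by saving intermediate state.</p>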

          </div>
        </div>

        
          <div class="next-post">
            <a class="purple-link" href="https://dragonfive.gitee.io/post/kuai-shou-de-ba-gua-bagua-lun-wen-fan-yi-yu-shang-xi/">
              <h3 class="post-title">
                Next: Kuaishou's Bagua: translation and commentary on the paper "BAGUA: Scaling up Distributed Learning with System Relaxations"
              </h3>
            </a>
          </div>
          
      </div>

      

      <div class="site-footer">
  <div class="slogan">Email (base64): MTY5MDMwMjk2M0BxcS5jb20=
</div>
  <div class="social-container">
    
      
        <a href="https://github.com/DragonFive" target="_blank">
          <i class="fab fa-github"></i>
        </a>
      
    
      
    
      
    
      
    
      
    
  </div>
  Powered by <a href="https://github.com/getgridea/gridea" target="_blank">Gridea</a> | <a class="rss" href="https://dragonfive.gitee.io//atom.xml" target="_blank">RSS</a>
</div>


    </div>
    <script type="application/javascript">

hljs.initHighlightingOnLoad()

var app = new Vue({
  el: '#app',
  data: {
    menuVisible: false,
  },
})

</script>




  </body>
</html>
